{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SAC的简化版实现,alpha使用常量代替.只使用一个value模型,而不是两个."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbiUlEQVR4nO3df2xT97038Lcd/8hPOyQQuxHJyBW0NOLH1tAGr5OqZ2SkNOvakj7qEJdGHU8rqEFAKjSyAdWqSeGhV+vK1tJK04CrK5op1dIOBu3yBBpW4UIIZIRfWacLTQTYKdDYSSC2Y3+eP2jOxZDSOMT+2sn7JR2JnPO1/T4Jfuv4/LJORARERHGmVx2AiCYmlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESmhrHzeeustTJs2DampqSgtLcXRo0dVRSEiBZSUz5/+9CdUV1fj1VdfxfHjxzF37lyUl5eju7tbRRwiUkCn4sLS0tJSPPzww/j9738PAAiHwygoKMDq1auxYcOGb318OBzGpUuXkJWVBZ1OF+u4RDRCIoLe3l7k5+dDr7/7to0hTpk0gUAAra2tqKmp0ebp9XqUlZXB5XIN+xi/3w+/36/9fPHiRRQXF8c8KxGNTldXF6ZOnXrXMXEvnytXriAUCsFms0XMt9lsOHfu3LCPqa2txa9+9as75nd1dcFiscQkJxFFz+fzoaCgAFlZWd86Nu7lMxo1NTWorq7Wfh5aQYvFwvIhSkAj2R0S9/KZPHkyUlJS4PF4IuZ7PB7Y7fZhH2M2m2E2m+MRj4jiJO5Hu0wmE0pKStDU1KTNC4fDaGpqgsPhiHccIlJEyceu6upqVFVVYd68eXjkkUfw29/+Fv39/XjhhRdUxCEiBZSUz3PPPYcvv/wSmzdvhtvtxne/+1189NFHd+yEJqLxS8l5PvfK5/PBarXC6/VyhzNRAonmvclru4hICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiWiLp9Dhw7hySefRH5+PnQ6HT744IOI5SKCzZs347777kNaWhrKysrw+eefR4y5du0ali5dCovFguzsbCxfvhx9fX33tCJElFyiLp/+/n7MnTsXb7311rDLt27dim3btuGdd97BkSNHkJGRgfLycgwMDGhjli5ditOnT6OxsRF79+7FoUOH8NJLL41+LYgo+cg9ACANDQ3az+FwWOx2u7z++uvavJ6eHjGbzfLee++JiMiZM2cEgLS0tGhj9u/fLzqdTi5evDii1/V6vQJAvF7vvcQnojEWzXtzTPf5nD9/Hm63G2VlZdo8q9WK0tJSuFwuAIDL5UJ2djbmzZunjSkrK4Ner8eRI0eGfV6/3w+fzxcxEVFyG9PycbvdAACbzRYx32azacvcbjfy8vIilhsMBuTk5GhjbldbWwur1apNBQUFYxmbiBRIiqNdNTU18Hq92tTV1aU6EhHdozEtH7vdDgDweDwR8z0ej7bMbreju7s7Yvng4CCuXbumjbmd2WyGxWKJmIgouY1p+RQVFcFut6OpqUmb5/P5cOTIETgcDgCAw+FAT08PWltbtTEHDhxAOBxGaWnpWMYhogRmiPYBfX19+Ne//qX9fP78ebS1tSEnJweFhYVYu3Ytfv3rX2PGjBkoKirCpk2bkJ+fj6effhoA8OCDD+Lxxx/Hiy++iHfeeQfBYBCrVq3CT3/6U+Tn54/ZihFRgov2UNrBgwcFwB1TVVWViNw83L5p0yax2WxiNptlwYIF0tHREfEcV69elSVLlkhmZqZYLBZ54YUXpLe3d8QZeKidKDFF897UiYgo7L5R8fl8sFqt8Hq93P9DlECieW8mxdEuIhp/WD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlIiqfGpra/Hwww8jKysLeXl5ePrpp9HR0RExZmBgAE6nE7m5ucjMzERlZSU8Hk/EmM7OTlRUVCA9PR15eXlYv349BgcH731tiChpRFU+zc3NcDqd+Oyzz9DY2IhgMIiFCxeiv79fG7Nu3Trs2bMH9fX1aG5uxqVLl7B48WJteSgUQkVFBQKBAA4fPoxdu3Zh586d2Lx589itFRElPrkH3d3dAkCam5tFRKSnp0eMRqPU19drY86ePSsAxOVyiYjIvn37RK/Xi9vt1sZs375dLBaL+P3+Eb2u1+sVAOL1eu8lPhGNsWjem/e0z8fr9QIAcnJyAACtra0IBoMoKyvTxsycOROFhYVwuVwAAJfLhdmzZ8Nms2ljysvL4fP5cPr06WFfx+/3w+fzRUxElNxGXT7hcBhr167Fo48+ilmzZgEA3G43TCYTsrOzI8babDa43W5tzK3FM7R8aNlwamtrYbVatamgoGC0sYkoQYy6fJxOJ06dOoW6urqxzDOsmpoaeL1eberq6or5axJRbBlG86BVq1Zh7969OHToEKZOnarNt9vtCAQC6Onpidj68Xg8sNvt2pijR49GPN/Q0bChMbczm80wm82jiUpECSqqLR8RwapVq9DQ0IADBw6gqKgoYnlJSQmMRiOampq0eR0dHejs7ITD4QAAOBwOtLe3o7u7WxvT2NgIi8WC4uLie1kXIkoiUW35OJ1O7N69Gx9++CGysrK0fTRWqxVpaWmwWq1Yvnw5qqurkZOTA4vFgtWrV8PhcGD+/PkAgIULF6K4uBjLli3D1q1b4Xa7sXHjRjidTm7dEE0k0RxGAzDstGPHDm3MjRs35OWXX5ZJkyZJenq6PPPMM3L58uWI57lw4YIsWrRI0tLSZPLkyfLKK69IMBgccQ4eaidKTNG8N3UiIuqqb3R8Ph+sViu8Xi8sFovqOET0tWjem7y2i4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlBjVVe1EY0HCYYT6+xEOBKBLSUFKWhp0JhN0Op3qaBQHLB+KOxFB8No1fLl/P7xHjyJw5Qr0ZjPSp09H3hNPIGvOHOhSUlTHpBhj+VBciQj8ly7hwptvor+jA/j60sJQXx+8V6+ir70d+c8/jynl5SygcY77fCiuQtevo/Pdd9F/7hwkHMZXfj+OXbmCz30+hEUQun4dF//zP+E9fhxJeM0zRYFbPhRX3pYW9J48CRFBZ38/Np04gQ6vFxkGA/7P/ffjuaIi4Pp1eBoakDVrFlLS0lRHphjhlg/FjYTD6D15EgiHIQD+b3s7zvT0ICQCXzCI3589i1NffQUAuP7f/43QjRtqA1NMsXwoboLXrsHb0qL97AsGI5YHwmH4Q6F4xyJFWD4UN4P9/Qj7/QAAHYD/ZbfDcMth9fstFnwnM1NROoo37vOhuAlfvw4JhwEAOp0OVdOnI8toxP+7fBn3paXhxfvvR15qquKUFC8sH4qb0PXr2qF1ADDo9fjf06bh2WnTMLT9wxMMJw6WD8XN9fPnIbft59HpdBiubgyZmTzPZ5zjPh+Km/DAwIjHpk6dihR+BBvXWD6UkPRpaYCe/z3HM/51KS6iPVs5JS2NH7vGOZYPxYeIdph9JPRGI8Cdz+May4fiQwThaM5Y1ul45GucY/lQXEg4fPNQO9HXWD4UFxIKwd/dPfIHcGfzuMe/MMWFDA4iMNLy0euRcf/9sQ1EyrF8KPHodDBYLKpTUIyxfCjh6HQ6pKSnq45BMcbyobiQUCjiuq5vw/IZ/1g+FBfhGze0K9pHgicYjn8sH4qL0MAAEEX50PjH8qG4CPX13fzoRfQ1lg/FxY0vvhjxVe3GSZOg5xXt4x7Lh+Iimq0eY24uv7ViAoiqfLZv3445c+bAYrHAYrHA4XBg//792vKBgQE4nU7k5uYiMzMTlZWV8Hg8Ec/R2dmJiooKpKenIy8vD+vXr8fg4ODYrA2NC3qzmTucJ4Coymfq1KnYsmULWltbcezYMfzwhz/EU089hdOnTwMA1q1bhz179qC+vh7Nzc24dOkSFi9erD0+FAqhoqICgUAAhw8fxq5du7Bz505s3rx5bNeKEoqIRHWYXZ+aCp2BN9kc73Ryj18LmZOTg9dffx3PPvsspkyZgt27d+PZZ58FAJw7dw4PPvggXC4X5s+fj/379+PHP/4xLl26BJvNBgB455138POf/xxffvklTCbTsK/h9/vhv+V2DD6fDwUFBfB6vbDwTNiEJyJw19fj0n/914jGT/rBD1C0fj2vak9CPp8PVqt1RO/NUe/zCYVCqKurQ39/PxwOB1pbWxEMBlFWVqaNmTlzJgoLC+FyuQAALpcLs2fP1ooHAMrLy+Hz+bStp+HU1tbCarVqU0FBwWhjkwoiCPX3R/UQFs/4F3X5tLe3IzMzE2azGStWrEBDQwOKi4vhdrthMpmQnZ0dMd5ms8HtdgMA3G53RPEMLR9a9k1qamrg9Xq1qaurK9rYpNLX38FOdKuoP1g/8MADaGtrg9frxfvvv4+qqio0NzfHIpvGbDbDbDbH9DUodiQcxo0LF0Y8nofZJ4aoy8dkMmH69OkAgJKSErS0tODNN9/Ec889h0AggJ6enoitH4/HA7vdDgCw2+04evRoxPMNHQ0bGkPjkAgGvd4RD8+YMSOGYShR3PN5PuFwGH6/HyUlJTAajWhqatKWdXR0oLOzEw6HAwDgcDjQ3t6O7lvu69LY2AiLxYLi4uJ7jULjRAq/MnlCiGrLp6amBosWLUJhYSF6e3uxe/dufPLJJ/j4449htVqxfPlyVFdXIycnBxaLBatXr4bD4cD8+fMBAAsXLkRxcTGWLVuGrVu3wu12Y+PGjXA6nfxYRTfxdhoTRlTl093djeeffx6XL1+G1WrFnDlz8PHHH+NHP/oRAOCNN96AXq9HZWUl/H4/ysvL8fbbb2uPT0lJwd69e7Fy5Uo4HA5kZGSgqqoKr7322tiuFSUUCYWi+uqclIyMGKahRHHP5/moEM25BKRe8KuvcPaVVxC8cuXbB+t0mPkf/8H9PkkqLuf5EI1UeGAAwkto6DYsH4q5oM+HcCCgOgYlGJYPxZzf7UZ4hCcZ6lNToefBhwmB5UMJxTRlCoy3nSVP4xPLhxKK3mSCzmhUHYPigOVDMSUiUd27WW82s3wmCJYPxVz4xo0Rj9UZDLyR2ATB8qGYG4zminbeSmPCYPlQzIX6+lRHoATE8qHYEkH/uXMjHq7T87/kRMG/NMVcaIRfmQMAGQ88EMMklEhYPpRQDDzHZ8Jg+VBCMfCK9gmD5UMxJaFQVOf58F4+EwfLh2Iq7PcjHAyOeLzOYOA3V0wQLB+KqXAgwNtp0LBYPhRTwa++Qqi3V3UMSkAsH4qpwZ6eEX9nlz49HYasrBgnokTB8qGEYcjM5KH2CYTlQwlDZzRCbzKpjkFxwvKhhKFn+UwoLB9KGGn/9m+8l88EwvKhmDLm5IzoG0j1aWmYUl7OC0snEP6lKabSCgsx6Qc/uPt9enQ6ZJeWIp3f1TWhsHwotvR65C9ZAktJyTcWUNbs2cj/93/nHQwnmKi+LpkoWjqdDobsbExbvRqehgZ8dfgwAlevAgCMkyYhu7QU9spKGHNzeVnFBMOvS6a4EBFABMGrV2+WjwiMOTkwTZkC6HQsnnEimvcmt3woLnQ6HaDTwTRlys3CoQmP+3yISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlLinspny5Yt0Ol0WLt2rTZvYGAATqcTubm5yMzMRGVlJTweT8TjOjs7UVFRgfT0dOTl5WH9+vUY5H1+iSaUUZdPS0sL3n33XcyZMydi/rp167Bnzx7U19ejubkZly5dwuLFi7XloVAIFRUVCAQCOHz4MHbt2oWdO3di8+bNo18LIko+Mgq9vb0yY8YMaWxslMcee0zWrFkjIiI9PT1iNBqlvr5eG3v27FkBIC6XS0RE9u3bJ3q9XtxutzZm+/btYrFYxO/3D/t6AwMD4vV6tamrq0sAiNfrHU18IooRr9c74vfmqLZ8nE4nKioqUFZWFjG/tbUVwWAwYv7MmTNRWFgIl8sFAHC5XJg9ezZsNps2pry8HD6fD6dPnx729Wpra2G1WrWpoKBgNLGJKIFEXT51dXU4fvw4amtr71jmdrthMpmQfdtNwG02G9xutzbm1uIZWj60bDg1NTXwer3a1NXVFW1sIkowUV1Y2tXVhTVr1qCxsRGpqamxynQHs9kMs9kct9cjotiLasuntbUV3d3deOihh2AwGGAwGNDc3Ixt27bBYDDAZrMhEAigp6cn4nEejwd2ux0AYLfb7zj6NfTz0BgiGv+iKp8FCxagvb0dbW1t2jRv3jwsXbpU+7fRaERTU5P2mI6ODnR2dsLhcAAAHA4H2tvb0d3drY1pbGyExWJBcXHxGK0WESW6qD52ZWVlYdasWRHzMjIykJubq81fvnw5qqurkZOTA4vFgtWrV8PhcGD+/PkAgIULF6K4uBjLli3D1q1b4Xa7sXHjRjidTn60IppAxvxmYm+88Qb0ej0qKyvh9/tRXl6Ot99+W1uekpKCvXv3YuXKlXA4HMjIyEBVVRVee+21sY5CRAmMt1ElojETzXuT13YRkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpYVAdYDREBADg8/kUJyGiWw29J4feo3eTlOVz9epVAEBBQYHiJEQ0nN7eXlit1ruOScryycnJAQB0dnZ+6womGp/Ph4KCAnR1dcFisaiOM2LMHV/JmltE0Nvbi/z8/G8dm5Tlo9ff3FVltVqT6g9zK4vFkpTZmTu+kjH3SDcIuMOZiJRg+RCREklZPmazGa+++irMZrPqKFFL1uzMHV/JmjsaOhnJMTEiojGWlFs+RJT8WD5EpATLh4iUYPkQkRIsHyJSIinL56233sK0adOQmpqK0tJSHD16VGmeQ4cO4cknn0R+fj50Oh0++OCDiOUigs2bN+O+++5DWloaysrK8Pnnn0eMuXbtGpYuXQqLxYLs7GwsX74cfX19Mc1dW1uLhx9+GFlZWcjLy8PTTz+Njo6OiDEDAwNwOp3Izc1FZmYmKisr4fF4IsZ0dnaioqIC6enpyMvLw/r16zE4OBiz3Nu3b8ecOXO0s38dDgf279+f0JmHs2XLFuh0Oqxduzbpso8JSTJ1dXViMpnkj3/8o5w+fVpefPFFyc7OFo/HoyzTvn375Je//KX8+c9/FgDS0NAQsXzLli1itVrlgw8+kH/84x/yk5/8RIqKiuTGjRvamMcff1zmzp0rn332mfz973+X6dOny5IlS2Kau7y8XHbs2CGnTp2StrY2eeKJJ6SwsFD6+vq0MStWrJCCggJpamqSY8eOyfz58+X73/++tnxwcFBmzZolZWVlcuLECdm3b59MnjxZampqYpb7L3/5i/z1r3+Vf/7zn9LR0SG/+MUvxGg0yqlTpxI28+2OHj0q06ZNkzlz5siaNWu0+cmQfawkXfk88sgj4nQ6tZ9DoZDk5+dLbW2twlT/4/byCYfDYrfb5fXXX9fm9fT0iNlslvfee09ERM6cOSMApKWlRRuzf/9+0el0cvHixbhl7+7uFgDS3Nys5TQajVJfX6+NOXv2rAAQl8slIjeLV6/Xi9vt1sZs375dLBaL+P3+uGWfNGmS/OEPf0iKzL29vTJjxgxpbGyUxx57TCufZMg+lpLqY1cgEEBrayvKysq0eXq9HmVlZXC5XAqTfbPz58/D7XZHZLZarSgtLdUyu1wuZGdnY968edqYsrIy6PV6HDlyJG5ZvV4vgP+5a0BrayuCwWBE9pkzZ6KwsDAi++zZs2Gz2bQx5eXl8Pl8OH36dMwzh0Ih1NXVob+/Hw6HIykyO51OVFRURGQEkuP3PZaS6qr2K1euIBQKRfziAcBms+HcuXOKUt2d2+0GgGEzDy1zu93Iy8uLWG4wGJCTk6ONibVwOIy1a9fi0UcfxaxZs7RcJpMJ2dnZd80+3LoNLYuV9vZ2OBwODAwMIDMzEw0NDSguLkZbW1vCZgaAuro6HD9+HC0tLXcsS+TfdywkVflQ7DidTpw6dQqffvqp6igj8sADD6CtrQ1erxfvv/8+qqqq0NzcrDrWXXV1dWHNmjVobGxEamqq6jjKJdXHrsmTJyMlJeWOvf8ejwd2u11RqrsbynW3zHa7Hd3d3RHLBwcHce3atbis16pVq7B3714cPHgQU6dO1ebb7XYEAgH09PTcNftw6za0LFZMJhOmT5+OkpIS1NbWYu7cuXjzzTcTOnNrayu6u7vx0EMPwWAwwGAwoLm5Gdu2bYPBYIDNZkvY7LGQVOVjMplQUlKCpqYmbV44HEZTUxMcDofCZN+sqKgIdrs9IrPP58ORI0e0zA6HAz09PWhtbdXGHDhwAOFwGKWlpTHLJiJYtWoVGhoacODAARQVFUUsLykpgdFojMje0dGBzs7OiOzt7e0R5dnY2AiLxYLi4uKYZb9dOByG3+9P6MwLFixAe3s72tratGnevHlYunSp9u9EzR4Tqvd4R6uurk7MZrPs3LlTzpw5Iy+99JJkZ2dH7P2Pt97eXjlx4oScOHFCAMhvfvMbOXHihHzxxRcicvNQe3Z2tnz44Ydy8uRJeeqpp4Y91P69731Pjhw5Ip9++qnMmDEj5ofaV65cKVarVT755BO5fPmyNl2/fl0bs2LFCiksLJQDBw7IsWPHxOFwiMPh0JYPHfpduHChtLW1yUcffSRTpkyJ6aHfDRs2SHNzs5w/f15OnjwpGzZsEJ1OJ3/7298SNvM3ufVoV7Jlv1dJVz4iIr/73e+ksLBQTCaTPPLII/LZZ58pzXPw4EEBcMdUVVUlIjcPt2/atElsNpuYzWZZsGCBdHR0RDzH1atXZcmSJZKZmSkWi0VeeOEF6e3tjWnu4TIDkB07dmhjbty4IS+//LJMmjRJ0tPT5ZlnnpHLly9HPM+FCxdk0aJFkpaWJpMnT5ZXXnlFgsFgzHL/7Gc/k+985ztiMplkypQpsmDBAq14EjXzN7m9fJIp+73i/XyISImk2udDROMHy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREv8fYIGIuQnHdDwAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(\n",
    "            [action * 2])\n",
    "        over = terminated or truncated\n",
    "\n",
    "        #偏移reward,便于训练\n",
    "        reward = (reward + 8) / 8\n",
    "\n",
    "        #限制最大步数\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            over = True\n",
    "\n",
    "        return state, reward, over\n",
    "\n",
    "    #打印游戏图像\n",
    "    def show(self):\n",
    "        from matplotlib import pyplot as plt\n",
    "        plt.figure(figsize=(3, 3))\n",
    "        plt.imshow(self.env.render())\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()\n",
    "\n",
    "env.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.1246],\n",
       "         [0.2240]], grad_fn=<TanhBackward0>),\n",
       " tensor([[0.8703],\n",
       "         [0.9434]], grad_fn=<ExpBackward0>))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class ModelAction(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.s = torch.nn.Sequential(\n",
    "            torch.nn.Linear(3, 64),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(64, 64),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "        self.mu = torch.nn.Sequential(\n",
    "            torch.nn.Linear(64, 1),\n",
    "            torch.nn.Tanh(),\n",
    "        )\n",
    "        self.sigma = torch.nn.Sequential(\n",
    "            torch.nn.Linear(64, 1),\n",
    "            torch.nn.Tanh(),\n",
    "        )\n",
    "\n",
    "    def forward(self, state):\n",
    "        state = self.s(state)\n",
    "        return self.mu(state), self.sigma(state).exp()\n",
    "\n",
    "\n",
    "model_action = ModelAction()\n",
    "\n",
    "model_action(torch.randn(2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0718],\n",
       "        [0.1226]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_value = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 1),\n",
    ")\n",
    "\n",
    "model_value_next = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 1),\n",
    ")\n",
    "\n",
    "model_value_next.load_state_dict(model_value.state_dict())\n",
    "\n",
    "model_value(torch.randn(2, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/cuda117/lib/python3.10/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)\n",
      "  if not isinstance(terminated, (bool, np.bool8)):\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "37.465940609927564"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "\n",
    "\n",
    "#玩一局游戏并记录数据\n",
    "def play(show=False):\n",
    "    data = []\n",
    "    reward_sum = 0\n",
    "\n",
    "    state = env.reset()\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据概率采样\n",
    "        mu, sigma = model_action(torch.FloatTensor(state).reshape(1, 3))\n",
    "        action = random.normalvariate(mu=mu.item(), sigma=sigma.item())\n",
    "\n",
    "        next_state, reward, over = env.step(action)\n",
    "\n",
    "        data.append((state, action, reward, next_state, over))\n",
    "        reward_sum += reward\n",
    "\n",
    "        state = next_state\n",
    "\n",
    "        if show:\n",
    "            display.clear_output(wait=True)\n",
    "            env.show()\n",
    "\n",
    "    return data, reward_sum\n",
    "\n",
    "\n",
    "play()[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_27560/3624659836.py:27: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:230.)\n",
      "  state = torch.FloatTensor([i[0] for i in data]).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(200,\n",
       " (array([0.18725726, 0.9823109 , 0.8032599 ], dtype=float32),\n",
       "  2.469941874166773,\n",
       "  0.7525465999398521,\n",
       "  array([0.09622052, 0.99536   , 1.8399931 ], dtype=float32),\n",
       "  False))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#数据池\n",
    "class Pool:\n",
    "\n",
    "    def __init__(self):\n",
    "        self.pool = []\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.pool)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.pool[i]\n",
    "\n",
    "    #更新动作池\n",
    "    def update(self):\n",
    "        #每次更新不少于N条新数据\n",
    "        old_len = len(self.pool)\n",
    "        while len(pool) - old_len < 200:\n",
    "            self.pool.extend(play()[0])\n",
    "\n",
    "        #只保留最新的N条数据\n",
    "        self.pool = self.pool[-2_0000:]\n",
    "\n",
    "    #获取一批数据样本\n",
    "    def sample(self):\n",
    "        data = random.sample(self.pool, 64)\n",
    "\n",
    "        state = torch.FloatTensor([i[0] for i in data]).reshape(-1, 3)\n",
    "        action = torch.FloatTensor([i[1] for i in data]).reshape(-1, 1)\n",
    "        reward = torch.FloatTensor([i[2] for i in data]).reshape(-1, 1)\n",
    "        next_state = torch.FloatTensor([i[3] for i in data]).reshape(-1, 3)\n",
    "        over = torch.LongTensor([i[4] for i in data]).reshape(-1, 1)\n",
    "\n",
    "        return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "pool = Pool()\n",
    "pool.update()\n",
    "state, action, reward, next_state, over = pool.sample()\n",
    "\n",
    "len(pool), pool[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer_action = torch.optim.Adam(model_action.parameters(), lr=5e-4)\n",
    "optimizer_value = torch.optim.Adam(model_value.parameters(), lr=5e-3)\n",
    "\n",
    "\n",
    "def soft_update(_from, _to):\n",
    "    for _from, _to in zip(_from.parameters(), _to.parameters()):\n",
    "        value = _to.data * 0.995 + _from.data * 0.005\n",
    "        _to.data.copy_(value)\n",
    "\n",
    "\n",
    "def get_action_entropy(state):\n",
    "    mu, sigma = model_action(torch.FloatTensor(state).reshape(-1, 3))\n",
    "    dist = torch.distributions.Normal(mu, sigma)\n",
    "\n",
    "    action = dist.rsample()\n",
    "\n",
    "    return action, sigma\n",
    "\n",
    "\n",
    "def requires_grad(model, value):\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad_(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.08588205277919769"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_value(state, action, reward, next_state, over):\n",
    "    requires_grad(model_value, True)\n",
    "    requires_grad(model_action, False)\n",
    "\n",
    "    #计算target\n",
    "    with torch.no_grad():\n",
    "        #计算动作和熵\n",
    "        next_action, entropy = get_action_entropy(next_state)\n",
    "\n",
    "        #评估next_state的价值\n",
    "        input = torch.cat([next_state, next_action], dim=1)\n",
    "        target = model_value_next(input)\n",
    "\n",
    "    #加权熵,熵越大越好\n",
    "    target = target + 5e-3 * entropy\n",
    "    target = target * 0.99 * (1 - over) + reward\n",
    "\n",
    "    #计算value\n",
    "    value = model_value(torch.cat([state, action], dim=1))\n",
    "\n",
    "    loss = torch.nn.functional.mse_loss(value, target)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer_value.step()\n",
    "    optimizer_value.zero_grad()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "train_value(state, action, reward, next_state, over)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.21783263981342316"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_action(state):\n",
    "    requires_grad(model_value, False)\n",
    "    requires_grad(model_action, True)\n",
    "\n",
    "    #计算action和熵\n",
    "    action, entropy = get_action_entropy(state)\n",
    "\n",
    "    #计算value\n",
    "    value = model_value(torch.cat([state, action], dim=1))\n",
    "\n",
    "    #加权熵,熵越大越好\n",
    "    loss = -(value + 5e-3 * entropy).mean()\n",
    "\n",
    "    #使用model_value计算model_action的loss\n",
    "    loss.backward()\n",
    "    optimizer_action.step()\n",
    "    optimizer_action.zero_grad()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "train_action(state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 400 6.716831857809709\n",
      "10 2400 113.66705195900381\n",
      "20 4400 176.78574515174154\n",
      "30 6400 180.2989842299859\n",
      "40 8400 179.1737776526894\n",
      "50 10400 173.79422176278672\n",
      "60 12400 178.08464277538593\n",
      "70 14400 179.8885112830408\n",
      "80 16400 177.19864279431127\n",
      "90 18400 178.51405497661727\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    model_action.train()\n",
    "    model_value.train()\n",
    "\n",
    "    #训练N次\n",
    "    for epoch in range(100):\n",
    "        #更新N条数据\n",
    "        pool.update()\n",
    "\n",
    "        #每次更新过数据后,学习N次\n",
    "        for i in range(200):\n",
    "            #采样一批数据\n",
    "            state, action, reward, next_state, over = pool.sample()\n",
    "\n",
    "            #训练\n",
    "            train_value(state, action, reward, next_state, over)\n",
    "            train_action(state)\n",
    "            soft_update(model_value, model_value_next)\n",
    "\n",
    "        if epoch % 10 == 0:\n",
    "            test_result = sum([play()[-1] for _ in range(20)]) / 20\n",
    "            print(epoch, len(pool), test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbrklEQVR4nO3df2xT97038LftxM7P45CE2M0laSJBSyN+dA0Qziqte0ZG2mVdWdNHrEJd1vK0KjNcKBNasxWqVZPCQ6V1paMwqRpwr9Rmore0K4O2eQINqzC/AlnDj2adRJtcgh1+NHYIxE7sz/MHzbk1pCwOsb+2+35JRyLn+7X9PgG/OT7HPjaJiICIKM7MqgMQ0TcTy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRQVj6bNm1CWVkZMjIyUFVVhcOHD6uKQkQKKCmfP//5z1i9ejWef/55HDt2DLNnz0ZNTQ16e3tVxCEiBUwqPlhaVVWFuXPn4g9/+AMAIBwOo6SkBCtWrMCzzz77L28fDofR09OD3NxcmEymWMclojESEfT396O4uBhm8833bdLilMkQDAbR1taGhoYGY53ZbEZ1dTXcbveotwkEAggEAsbPZ8+eRUVFRcyzEtH4dHd3Y8qUKTedE/fyuXDhAkKhEBwOR8R6h8OBTz75ZNTbNDY24je/+c0N67u7u6FpWkxyElH0/H4/SkpKkJub+y/nxr18xqOhoQGrV682fh7ZQE3TWD5ECWgsh0PiXj6FhYWwWCzwer0R671eL5xO56i3sdlssNls8YhHRHES97NdVqsVlZWVaGlpMdaFw2G0tLRA1/V4xyEiRZS87Fq9ejXq6+sxZ84czJs3D7///e8xMDCAxx9/XEUcIlJASfksXrwY58+fx7p16+DxeHD33Xfjvffeu+EgNBGlLiXv87lVfr8fdrsdPp+PB5yJEkg0z01+touIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlIi6vLZv38/HnzwQRQXF8NkMuHtt9+OGBcRrFu3DrfddhsyMzNRXV2NTz/9NGLOpUuXsGTJEmiahry8PCxduhSXL1++pQ0houQSdfkMDAxg9uzZ2LRp06jjGzZswMaNG7FlyxYcOnQI2dnZqKmpweDgoDFnyZIlOHnyJJqbm7Fr1y7s378fTz311Pi3goiSj9wCALJz507j53A4LE6nU1588UVjXV9fn9hsNnnjjTdEROTUqVMCQI4cOWLM2bNnj5hMJjl79uyYHtfn8wkA8fl8txKfiCZYNM/NCT3mc+bMGXg8HlRXVxvr7HY7qqqq4Ha7AQButxt5eXmYM2eOMae6uhpmsxmHDh0a9X4DgQD8fn/EQkTJbULLx+PxAAAcDkfEeofDYYx5PB4UFRVFjKelpSE/P9+Yc73GxkbY7XZjKSkpmcjYRKRAUpztamhogM/nM5bu7m7VkYjoFk1o+TidTgCA1+uNWO/1eo0xp9OJ3t7eiPHh4WFcunTJmHM9m80GTdMiFiJKbhNaPuXl5XA6nWhpaTHW+f1+HDp0CLquAwB0XUdfXx/a2tqMOXv37kU4HEZVVdVExiGiBJYW7Q0uX76Mf/7zn8bPZ86cQXt7O/Lz81FaWopVq1bht7/9LaZNm4by8nKsXbsWxcXFWLRoEQDgrrvuwv33348nn3wSW7ZswdDQEJYvX46f/OQnKC4unrANI6IEF+2ptH379gmAG5b6+noRuXa6fe3ateJwOMRms8mCBQuks7Mz4j4uXrwojz76qOTk5IimafL4449Lf3//mDPwVDtRYormuWkSEVHYfePi9/tht9vh8/l4/IcogUTz3EyKs11ElHpYPkSkBMuHiJSI+mwX0UQSEUAEwQsXELxwARCBtaAA1smTAbMZJpNJdUSKEZYPKSMiGO7rg+ett9B34ACCly4BIkifNAn2efPgfOQRWAsLWUApiuVDSogIhr/4Ap+98gr8x44BXznpOnTxIi7s2YPB7m6UrVwJa1ERCygF8ZgPqSGCnqamG4rnqy6fOIGz//EfkOHhOIejeGD5kBJXP/8cX3z00dcWzwjf0aMYuO5KmJQaWD4UdyKCCx98gNAYLp0bvnoVQxcuxCEVxRvLh+IvFMLVM2dUpyDFWD4UdxIKIczjON94LB+Ku/DwMA8iE8uH4k9CIcjQkOoYpBjLh+JOgkGEAwHVMUgxlg/F3dAXXyB4/vyY5pozMpBeUBDjRKQCy4fiL4pLSJkzMpCelxe7LKQMy4cSmslshik9XXUMigGWDyU2iwVmlk9KYvlQQjOZzTBZrapjUAywfCjuJBQa83Efk9kMcxovvpCKWD4Ud+HBwbFPNpmuLZRyWD4UdyG+x4fA8iEFotrzoZTF8qG4Y/kQwPIhBQLnzo15rsli4TGfFMXyobgLeL1jnps1bRpMZv4zTUX8W6WEZsnK4p5PimL5UEIzZ2SojkAxwvKhhGax2bjnk6JYPhRXEsUn2gHAbLPFKAmpxvKh+AqHo7qEqokfrUhZLB+KKwmHo7uEqsnEbytNUSwfiisJhRDm9ZsJLB+Kt1AIEgyqTkEJgOVDcSXhMPd8CECU5dPY2Ii5c+ciNzcXRUVFWLRoETo7OyPmDA4OwuVyoaCgADk5Oairq4P3une0dnV1oba2FllZWSgqKsKaNWswzO9x+kYY7u9HwOMZ01yzzYaMf/u3GCciVaIqn9bWVrhcLhw8eBDNzc0YGhrCwoULMTAwYMx55pln8O6772LHjh1obW1FT08PHn74YWM8FAqhtrYWwWAQBw4cwPbt27Ft2zasW7du4raKElcUB5xNFgssOTkxDkSqmCTaN158xfnz51FUVITW1lZ85zvfgc/nw+TJk/H666/jkUceAQB88sknuOuuu+B2uzF//nzs2bMHP/zhD9HT0wOHwwEA2LJlC375y1/i/PnzsI7hkpl+vx92ux0+nw+apo03Pikw+N//jVP//u9jOt1uycnBnevXI7O0NA7JaCJE89y8pWM+Pp8PAJCfnw8AaGtrw9DQEKqrq40506dPR2lpKdxuNwDA7XZj5syZRvEAQE1NDfx+P06ePDnq4wQCAfj9/oiFUp/JbIaZ129OWeMun3A4jFWrVuHee+/FjBkzAAAejwdWqxV5133PksPhgOfL1/kejyeieEbGR8ZG09jYCLvdbiwlJSXjjU3JxGTixeNT2LjLx+Vy4cSJE2hqaprIPKNqaGiAz+czlu7u7pg/JsWGhMNjn8w9n5Q2rveuL1++HLt27cL+/fsxZcoUY73T6UQwGERfX1/E3o/X64XT6TTmHD58OOL+Rs6Gjcy5ns1mg42f8UkJ4cHBsX9zBcAPlaawqPZ8RATLly/Hzp07sXfvXpSXl0eMV1ZWIj09HS0tLca6zs5OdHV1Qdd1AICu6+jo6EBvb68xp7m5GZqmoaKi4la2hZJAOBDAuM9wUEqJas/H5XLh9ddfxzvvvIPc3FzjGI3dbkdmZibsdjuWLl2K1atXIz8/H5qmYcWKFdB1HfPnzwcALFy4EBUVFXjsscewYcMGeDwePPfcc3C5XNy7+QYIBwJRfVc7pa6oymfz5s0AgO9+97sR67du3Yqf/exnAICXXnoJZrMZdXV1CAQCqKmpwauvvmrMtVgs2LVrF5YtWwZd15GdnY36+nq88MILt7YllBRCvHg8fSmq8hnLW4IyMjKwadMmbNq06Wvn3H777di9e3c0D00p4upnnwFjPOhszsjgJ9pTGD/bRXE1HMV7tDLLymBKT49hGlKJ5UMJy8xLqKY0lg8lLLPNxpddKYzlQwnLbLVyzyeFsXwobsZ18XiWT8pi+VD8iPDi8WRg+VD8hMMIR3MJVV48PqWxfChuRITXbyYDy4fiJ9o9H0ppLB+KGwmFELp6VXUMShAsH4qb0NWrGOzqAnDtJVggFMKFwUEMhkI3TrZYYC0qinNCiieeTqC4knAYIoKugQH0XLmCPKsVVrMZGRZLxDyTxQJrYaGilBQP3POhuOseGEBfMIi5hYVwZmbiH34/PvX7Ef7K+4BMJtO19/lQyuKeD8VVMBxGz5UrmFNYiLNXrmDt8ePo9PmQnZaG/3PHHVhcXn7tf0SWT8rjng/F1eWhIeRZrTCbTPi/HR041deHkAj8Q0P4w+nT6Lh0CW0XLyIkwvJJcSwfihtLZiYyy8pgMV/7Z+e/7ssDg+EwLgaDOHj+PAbDYV48PsWxfChuLFlZyJ4+HSERmAD8L6cTaV95B/MdmobpmoZ7CgqQm5PDj1ekOP7tUlyVLVyIg//1XwiJoH7qVOSmp+P/nTuH2zIz8eQdd2BIBLdlZiL/vvuQxm+jTWnc86G40srKMOeJJ3BqYAAA8L/LyrBF1/H83XcjMy0NnitXMH3ePDgWLYLJzH+eqYx7PhRXJrMZ33rkEYQBHP/P/0TmwABMIgiEwwhZLLjvgQcw9amnYP3yK7gpdbF8KO7M6emYs3gx7rj3XvQcOYLAuXOwZWfjtlmzkFtRAUtmpuqIFAcsH1LCZDbDXloKe2mp6iikCF9UE5ESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIloiqfzZs3Y9asWdA0DZqmQdd17NmzxxgfHByEy+VCQUEBcnJyUFdXB6/XG3EfXV1dqK2tRVZWFoqKirBmzRoMDw9PzNYQUdKIqnymTJmC9evXo62tDUePHsX3vvc9PPTQQzh58iQA4JlnnsG7776LHTt2oLW1FT09PXj44YeN24dCIdTW1iIYDOLAgQPYvn07tm3bhnXr1k3sVhFR4pNbNGnSJHnttdekr69P0tPTZceOHcbY6dOnBYC43W4REdm9e7eYzWbxeDzGnM2bN4umaRIIBL72MQYHB8Xn8xlLd3e3ABCfz3er8YloAvl8vjE/N8d9zCcUCqGpqQkDAwPQdR1tbW0YGhpCdXW1MWf69OkoLS2F2+0GALjdbsycORMOh8OYU1NTA7/fb+w9jaaxsRF2u91YSkpKxhubiBJE1OXT0dGBnJwc2Gw2PP3009i5cycqKirg8XhgtVqRl5cXMd/hcMDj8QAAPB5PRPGMjI+MfZ2Ghgb4fD5j6e7ujjY2ESWYqK/hfOedd6K9vR0+nw9vvvkm6uvr0draGotsBpvNBhu/OpcopURdPlarFVOnTgUAVFZW4siRI3j55ZexePFiBINB9PX1Rez9eL1eOJ1OAIDT6cThw4cj7m/kbNjIHCL6Zrjl9/mEw2EEAgFUVlYiPT0dLS0txlhnZye6urqg6zoAQNd1dHR0oLe315jT3NwMTdNQUVFxq1GIKIlEtefT0NCABx54AKWlpejv78frr7+ODz/8EO+//z7sdjuWLl2K1atXIz8/H5qmYcWKFdB1HfPnzwcALFy4EBUVFXjsscewYcMGeDwePPfcc3C5XHxZRfQNE1X59Pb24qc//SnOnTsHu92OWbNm4f3338f3v/99AMBLL70Es9mMuro6BAIB1NTU4NVXXzVub7FYsGvXLixbtgy6riM7Oxv19fV44YUXJnariCjhmUREVIeIlt/vh91uh8/ng6ZpquMQ0ZeieW7ys11EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESt1Q+69evh8lkwqpVq4x1g4ODcLlcKCgoQE5ODurq6uD1eiNu19XVhdraWmRlZaGoqAhr1qzB8PDwrUQhoiQz7vI5cuQI/vjHP2LWrFkR65955hm8++672LFjB1pbW9HT04OHH37YGA+FQqitrUUwGMSBAwewfft2bNu2DevWrRv/VhBR8pFx6O/vl2nTpklzc7Pcd999snLlShER6evrk/T0dNmxY4cx9/Tp0wJA3G63iIjs3r1bzGazeDweY87mzZtF0zQJBAKjPt7g4KD4fD5j6e7uFgDi8/nGE5+IYsTn8435uTmuPR+Xy4Xa2lpUV1dHrG9ra8PQ0FDE+unTp6O0tBRutxsA4Ha7MXPmTDgcDmNOTU0N/H4/Tp48OerjNTY2wm63G0tJScl4YhNRAom6fJqamnDs2DE0NjbeMObxeGC1WpGXlxex3uFwwOPxGHO+Wjwj4yNjo2loaIDP5zOW7u7uaGMTUYJJi2Zyd3c3Vq5ciebmZmRkZMQq0w1sNhtsNlvcHo+IYi+qPZ+2tjb09vbinnvuQVpaGtLS0tDa2oqNGzciLS0NDocDwWAQfX19Ebfzer1wOp0AAKfTecPZr5GfR+YQUeqLqnwWLFiAjo4OtLe3G8ucOXOwZMkS48/p6eloaWkxbtPZ2Ymuri7oug4A0HUdHR0d6O3tNeY0NzdD0zRUVFRM0GYRUaKL6mVXbm4uZsyYEbEuOzsbBQUFxvqlS5di9erVyM/Ph6ZpWLFiBXRdx/z58wEACxcuREVFBR577DFs2LABHo8Hzz33HFwuF19aEX2DRFU+Y/HSSy/BbDajrq4OgUAANTU1ePXVV41xi8WCXbt2YdmyZdB1HdnZ2aivr8cLL7ww0VGIKIGZRERUh4iW3++H3W6Hz+eDpmmq4xDRl6J5bvKzXUSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqkqQ4wHiICAPD7/YqTENFXjTwnR56jN5OU5XPx4kUAQElJieIkRDSa/v5+2O32m85JyvLJz88HAHR1df3LDUw0fr8fJSUl6O7uhqZpquOMGXPHV7LmFhH09/ejuLj4X85NyvIxm68dqrLb7Un1F/NVmqYlZXbmjq9kzD3WHQIecCYiJVg+RKREUpaPzWbD888/D5vNpjpK1JI1O3PHV7LmjoZJxnJOjIhogiXlng8RJT+WDxEpwfIhIiVYPkSkBMuHiJRIyvLZtGkTysrKkJGRgaqqKhw+fFhpnv379+PBBx9EcXExTCYT3n777YhxEcG6detw2223ITMzE9XV1fj0008j5ly6dAlLliyBpmnIy8vD0qVLcfny5ZjmbmxsxNy5c5Gbm4uioiIsWrQInZ2dEXMGBwfhcrlQUFCAnJwc1NXVwev1Rszp6upCbW0tsrKyUFRUhDVr1mB4eDhmuTdv3oxZs2YZ7/7VdR179uxJ6MyjWb9+PUwmE1atWpV02SeEJJmmpiaxWq3ypz/9SU6ePClPPvmk5OXlidfrVZZp9+7d8utf/1reeustASA7d+6MGF+/fr3Y7XZ5++235e9//7v86Ec/kvLycrl69aox5/7775fZs2fLwYMH5W9/+5tMnTpVHn300Zjmrqmpka1bt8qJEyekvb1dfvCDH0hpaalcvnzZmPP0009LSUmJtLS0yNGjR2X+/Pny7W9/2xgfHh6WGTNmSHV1tRw/flx2794thYWF0tDQELPcf/nLX+Svf/2r/OMf/5DOzk751a9+Jenp6XLixImEzXy9w4cPS1lZmcyaNUtWrlxprE+G7BMl6cpn3rx54nK5jJ9DoZAUFxdLY2OjwlT/4/ryCYfD4nQ65cUXXzTW9fX1ic1mkzfeeENERE6dOiUA5MiRI8acPXv2iMlkkrNnz8Yte29vrwCQ1tZWI2d6errs2LHDmHP69GkBIG63W0SuFa/ZbBaPx2PM2bx5s2iaJoFAIG7ZJ02aJK+99lpSZO7v75dp06ZJc3Oz3HfffUb5JEP2iZRUL7uCwSDa2tpQXV1trDObzaiurobb7VaY7OudOXMGHo8nIrPdbkdVVZWR2e12Iy8vD3PmzDHmVFdXw2w249ChQ3HL6vP5APzPVQPa2towNDQUkX369OkoLS2NyD5z5kw4HA5jTk1NDfx+P06ePBnzzKFQCE1NTRgYGICu60mR2eVyoba2NiIjkBy/74mUVJ9qv3DhAkKhUMQvHgAcDgc++eQTRaluzuPxAMComUfGPB4PioqKIsbT0tKQn59vzIm1cDiMVatW4d5778WMGTOMXFarFXl5eTfNPtq2jYzFSkdHB3Rdx+DgIHJycrBz505UVFSgvb09YTMDQFNTE44dO4YjR47cMJbIv+9YSKryodhxuVw4ceIEPvroI9VRxuTOO+9Ee3s7fD4f3nzzTdTX16O1tVV1rJvq7u7GypUr0dzcjIyMDNVxlEuql12FhYWwWCw3HP33er1wOp2KUt3cSK6bZXY6nejt7Y0YHx4exqVLl+KyXcuXL8euXbuwb98+TJkyxVjvdDoRDAbR19d30+yjbdvIWKxYrVZMnToVlZWVaGxsxOzZs/Hyyy8ndOa2tjb09vbinnvuQVpaGtLS0tDa2oqNGzciLS0NDocjYbPHQlKVj9VqRWVlJVpaWox14XAYLS0t0HVdYbKvV15eDqfTGZHZ7/fj0KFDRmZd19HX14e2tjZjzt69exEOh1FVVRWzbCKC5cuXY+fOndi7dy/Ky8sjxisrK5Genh6RvbOzE11dXRHZOzo6IsqzubkZmqahoqIiZtmvFw6HEQgEEjrzggUL0NHRgfb2dmOZM2cOlixZYvw5UbPHhOoj3tFqamoSm80m27Ztk1OnTslTTz0leXl5EUf/462/v1+OHz8ux48fFwDyu9/9To4fPy6ff/65iFw71Z6XlyfvvPOOfPzxx/LQQw+Neqr9W9/6lhw6dEg++ugjmTZtWsxPtS9btkzsdrt8+OGHcu7cOWO5cuWKMefpp5+W0tJS2bt3rxw9elR0XRdd143xkVO/CxculPb2dnnvvfdk8uTJMT31++yzz0pra6ucOXNGPv74Y3n22WfFZDLJBx98kLCZv85Xz3YlW/ZblXTlIyLyyiuvSGlpqVitVpk3b54cPHhQaZ59+/YJgBuW+vp6Ebl2un3t2rXicDjEZrPJggULpLOzM+I+Ll68KI8++qjk5OSIpmny+OOPS39/f0xzj5YZgGzdutWYc/XqVfn5z38ukyZNkqysLPnxj38s586di7ifzz77TB544AHJzMyUwsJC+cUvfiFDQ0Mxy/3EE0/I7bffLlarVSZPniwLFiwwiidRM3+d68snmbLfKl7Ph4iUSKpjPkSUOlg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJT4/9+IVTx/KZPZAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "184.1999658049721"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "play(True)[-1]"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:cuda117]",
   "language": "python",
   "name": "conda-env-cuda117-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
